import tensorflow as tf
from nets.net import Net


class ResNet(Net):
  """Pre-activation ResNet (CIFAR and ImageNet variants) built on the Net helpers.

  ``model`` picks the variant from the last dimension of the input: 32 selects
  the CIFAR network and 224 the ImageNet network. The spatial convolution inside
  each residual unit is chosen by a ``mode`` string (see ``_conv_all``). On the
  CIFAR path, ``mode == 'invert'`` replaces the residual units with
  inverted-residual units (expansion -> depthwise -> projection, see
  ``_expanded_conv``).

  NOTE(review): ``self._shape(x)[1]`` is used as the channel count and
  ``self._shape(x)[-1]`` as the image size. This suggests NCHW inputs; confirm
  against ``Net._shape``.
  """

  def _expanded_conv(self, x, ksize, c_out, stride=1, expansion_size=6, padding='SAME', name='expanded_conv'):
    """Inverted-residual unit: 1x1 expansion -> depthwise conv -> 1x1 projection + shortcut.

    Args:
      x: input tensor; axis 1 is read as the channel count.
      ksize: kernel size of the depthwise convolution.
      c_out: number of output channels.
      stride: stride of the depthwise convolution (and of the shortcut projection).
      expansion_size: channel multiplier of the 1x1 expansion; 1 skips the expansion.
      padding: padding passed to every convolution of the main path.
      name: prefix for the variable scopes of the three main-path stages.

    Returns:
      Sum of the main path and the (possibly projected) shortcut.
    """
    x_orig = x
    c_in = self._shape(x)[1]
    if expansion_size != 1:
      with tf.variable_scope(name + '_expansion'):
        x = self._conv(x, 1, int(c_in * expansion_size), stride=1, padding=padding)
        x = self._batch_norm(x)
        x = self._activation(x)
    with tf.variable_scope(name + '_depthwise'):
      x = self._depthwise_conv(x, ksize, channel_multiplier=1, stride=stride, padding=padding)
      x = self._batch_norm(x)
      x = self._activation(x)
    with tf.variable_scope(name + '_projection'):
      # No activation after the projection (linear bottleneck).
      x = self._conv(x, 1, c_out, stride=1, padding=padding)
      x = self._batch_norm(x)
    with tf.variable_scope('SA'):
      # Project the shortcut with a 1x1 conv whenever the spatial size or the
      # width changes. Use `!=`: `is not` compares int identity, not value.
      if stride != 1 or c_in != c_out:
        x_orig = self._conv(x_orig, 1, c_out, stride)
      x = x + x_orig
    return x

  def _conv_all(self, x, c_out, stride=1, mode='c2sp'):
    """Apply the convolution flavour selected by ``mode``.

    Modes:
      'c2' / 'c3' / 'c4': plain conv with kernel size 2 / 3 / 4, SAME padding.
      'c2sp': ``self._conv2x2_sp``.   'c4sp': ``self._conv4x4_sp``.
      'shift': ``self._shift_conv``.  'sep': 3x3 ``self._separable_conv``.

    Raises:
      ValueError: if ``mode`` is not one of the modes above.
    """
    plain_ksize = {'c2': 2, 'c3': 3, 'c4': 4}
    if mode in plain_ksize:
      return self._conv(x, ksize=plain_ksize[mode], c_out=c_out, stride=stride, padding='SAME')
    if mode == 'c2sp':
      return self._conv2x2_sp(x, c_out=c_out, stride=stride)
    if mode == 'c4sp':
      return self._conv4x4_sp(x, c_out=c_out, stride=stride)
    if mode == 'shift':
      return self._shift_conv(x, c_out=c_out, stride=stride)
    if mode == 'sep':
      return self._separable_conv(x, ksize=3, c_out=c_out, stride=stride)
    # Fail loudly. Returning a sentinel such as -1 would be fed onward as if it
    # were a tensor and break much later.
    raise ValueError('Unsupported conv mode: %r' % (mode,))

  def _residual(self, x, c_out, stride=1, bottleneck=False, first=False, expansion=1, mode='c2sp'):
    """Pre-activation residual unit.

    Args:
      x: input tensor; axis 1 is read as the channel count.
      c_out: number of output channels of the unit.
      stride: stride of the first spatial conv (and of the shortcut projection).
      bottleneck: if True, use 1x1 -> spatial conv -> 1x1 with int(c_out / 4)
        inner channels; otherwise use two spatial convs.
      first: if True, skip the leading batch-norm + activation. The ImageNet
        path sets this for the first unit, whose input comes from a stem that
        already applied batch-norm + activation.
      expansion: width multiplier for the first spatial conv of a
        non-bottleneck unit.
      mode: convolution flavour passed to ``_conv_all``.

    Returns:
      Sum of the residual branch and the (possibly projected) shortcut.
    """
    c_in = self._shape(x)[1]
    x_orig = x
    # Pre-activation residual block.
    if not first:
      with tf.variable_scope('S0'):
        x = self._batch_norm(x)
        x = self._activation(x)

    if bottleneck:
      # Integer width for both inner convs. A plain `/` is float division in
      # Python 3 and would hand a float channel count to the spatial conv.
      c_mid = int(c_out / 4)
      # Re-enters the 'S0' scope used above; kept so variable names do not change.
      with tf.variable_scope('S0'):
        x = self._conv(x, 1, c_mid)
      with tf.variable_scope('S1'):
        x = self._batch_norm(x)
        x = self._activation(x)
        # The stride goes on the spatial (middle) conv of the bottleneck,
        # following fb.resnet.torch.
        x = self._conv_all(x, c_out=c_mid, stride=stride, mode=mode)
      with tf.variable_scope('S2'):
        x = self._batch_norm(x)
        x = self._activation(x)
        x = self._conv(x, 1, c_out)
    else:
      with tf.variable_scope('S0'):
        x = self._conv_all(x, c_out=c_out * expansion, stride=stride, mode=mode)
      with tf.variable_scope('S1'):
        x = self._batch_norm(x)
        x = self._activation(x)
        x = self._conv_all(x, c_out=c_out, stride=1, mode=mode)

    with tf.variable_scope('SA'):
      # Use `!=`: `is not` compares int identity, not value.
      if stride != 1 or c_in != c_out:
        x_orig = self._conv(x_orig, 1, c_out, stride)
      x = x_orig + x
    return x

  def model(self, x):
    """Build the network for input ``x`` and return the logits.

    The last dimension of ``x`` selects the variant:
      32: CIFAR ResNet (3 stages; unit type chosen by ``mode`` below).
      224: ImageNet ResNet (stem + 4 bottleneck stages of [3, 4, 6, 3] units).
    The architecture knobs (mode, expansion, widths, depth) are hard-coded below.
    The logit width is ``self.shape_y[1]``, presumably the number of classes
    (TODO confirm).

    Raises:
      ValueError: if the last dimension of ``x`` is neither 32 nor 224.
    """
    size = self._shape(x)[-1]

    if size == 32:
      print('ResNet for cifar dataset')

      num_residual = 9  # total layers: 6n+2 (basic units) / 9n+2 (bottleneck units)
      strides = [1, 2, 2]

      mode = 'c2sp'  # 'c2', 'c3', 'c4', 'c2sp', 'c4sp', 'shift', 'invert', 'sep'
      expansion = 1
      double = False

      channel = [16, 32, 64]

      if mode == 'shift':
        channel = [18, 36, 72]
        expansion = 3
      if mode == 'sep':
        channel = [30, 60, 120]
      if mode == 'invert':
        num_residual = int(num_residual * 2 / 3)
        expansion = 6

      if double:
        channel = [2 * c for c in channel]

      bottleneck = False
      if bottleneck:
        channel = [int(4 * c) for c in channel]

      with tf.variable_scope('init'):
        x = self._conv(x, 3, channel[0])

      if mode == 'invert':
        with tf.variable_scope('init'):
          x = self._batch_norm(x)
          x = self._activation(x)
        for i, stride in enumerate(strides):
          with tf.variable_scope('U%d-0' % i):
            x = self._expanded_conv(x, ksize=3, c_out=channel[i], stride=stride, expansion_size=expansion)
          for j in range(1, num_residual):
            with tf.variable_scope('U%d-%d' % (i, j)):
              x = self._expanded_conv(x, ksize=3, c_out=channel[i], stride=1, expansion_size=expansion)
        with tf.variable_scope('global_avg_pool'):
          x = self._activation(x)
          x = self._pool(x, 'GLO')
      else:
        for i, stride in enumerate(strides):
          with tf.variable_scope('U%d-0' % i):
            x = self._residual(x, channel[i], stride, bottleneck, expansion=expansion, mode=mode)
          for j in range(1, num_residual):
            with tf.variable_scope('U%d-%d' % (i, j)):
              x = self._residual(x, channel[i], 1, bottleneck, expansion=expansion, mode=mode)
        with tf.variable_scope('global_avg_pool'):
          x = self._batch_norm(x)
          x = self._activation(x)
          x = self._pool(x, 'GLO')

      with tf.variable_scope('logit'):
        x = self._fc(x, self.shape_y[1], name='fc', bias=False)
      return x

    if size == 224:
      print('ResNet for ImageNet dataset')

      with tf.variable_scope('init'):
        x = self._conv(x, 7, 64, stride=2)
        x = self._batch_norm(x)
        x = self._activation(x)
        x = self._pool(x, type='MAX', ksize=3, stride=2)

      num_residual = [3, 4, 6, 3]
      strides = [1, 2, 2, 2]

      mode = 'c2sp'  # 'c2sp', 'c3'

      # Narrower alternative: [128, 256, 512, 1024]
      channel = [256, 512, 1024, 2048]

      bottleneck = True

      for i, stride in enumerate(strides):
        with tf.variable_scope('U%d-0' % i):
          # The stem already applied batch-norm + activation, so the very first
          # unit skips its pre-activation.
          x = self._residual(x, channel[i], stride, bottleneck, first=(i == 0), mode=mode)
        for j in range(1, num_residual[i]):
          with tf.variable_scope('U%d-%d' % (i, j)):
            x = self._residual(x, channel[i], 1, bottleneck, mode=mode)

      with tf.variable_scope('global_avg_pool'):
        x = self._batch_norm(x)
        x = self._activation(x)
        x = self._pool(x, 'GLO')

      with tf.variable_scope('logit'):
        x = self._fc(x, self.shape_y[1], name='fc')
      return x

    # Raise instead of `assert False`: asserts are stripped under `python -O`,
    # which would silently return None here.
    raise ValueError('Unknown image size: %s' % (size,))

	  
